# -*- coding: utf-8 -*-
from pyhanlp import *
import zipfile
import os
from pyhanlp.static import download, remove_file, HANLP_DATA_PATH


def test_data_path():
    """
    Return the test-data directory ($root/data/test); the root directory is
    taken from the pyhanlp configuration (HANLP_DATA_PATH). The directory is
    created on first use.

    :return: path of the test data directory
    """
    data_path = os.path.join(HANLP_DATA_PATH, 'test')
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the check-then-create race that isdir()+mkdir() had.
    os.makedirs(data_path, exist_ok=True)
    return data_path


# Make sure a corpus exists locally; download (and unzip) it on first use.
def ensure_data(data_name, data_url):
    """
    Return the local path of *data_name*, downloading it from *data_url*
    (and unpacking it, for .zip archives) when it is not present yet.
    """
    base_dir = test_data_path()
    target = os.path.join(base_dir, data_name)
    if os.path.exists(target):
        return target

    is_zip = data_url.endswith('.zip')
    if is_zip:
        target += '.zip'
    download(data_url, target)
    if not is_zip:
        return target
    # Unpack the archive next to itself, then drop the .zip file.
    with zipfile.ZipFile(target, "r") as archive:
        archive.extractall(base_dir)
    remove_file(target)
    return target[:-len('.zip')]


## PKU 98 corpus; ensure_data downloads it automatically on first run.
PKU98 = ensure_data("pku98", "http://file.hankcs.com/corpus/pku98.zip")
# Full corpus file and its train/test split.
PKU199801 = os.path.join(PKU98, '199801.txt')
PKU199801_TRAIN = os.path.join(PKU98, '199801-train.txt')
PKU199801_TEST = os.path.join(PKU98, '199801-test.txt')
# Destination paths for the trained POS and NER models.
POS_MODEL = os.path.join(PKU98, 'pos.bin')
NER_MODEL = os.path.join(PKU98, 'ner.bin')

## ===============================================
## Java classes for CRF named-entity recognition, bridged via pyhanlp/JPype.
PerceptronSegmenter = JClass('com.hankcs.hanlp.model.perceptron.PerceptronSegmenter')
CRFNERecognizer = JClass('com.hankcs.hanlp.model.crf.CRFNERecognizer')
AbstractLexicalAnalyzer = JClass('com.hankcs.hanlp.tokenizer.lexical.AbstractLexicalAnalyzer')
Utility = JClass('com.hankcs.hanlp.model.perceptron.utility.Utility')
PerceptronPOSTagger = JClass('com.hankcs.hanlp.model.perceptron.PerceptronPOSTagger')


def train(corpus, model):
    """
    Train a CRF NER recognizer from scratch.

    :param corpus: path of the training corpus (PKU format)
    :param model: path where the trained model is written
    :return: the trained recognizer
    """
    # Per the HanLP API (and the original comment here): the zero-argument
    # constructor LOADS the default model named in the config file; an
    # explicit None is required to get a blank, untrained recognizer.
    # The original code called CRFNERecognizer() — contradicting its own
    # "blank" comment — so training would not start from scratch.
    recognizer = CRFNERecognizer(None)
    recognizer.train(corpus, model)
    return recognizer


def test(recognizer, test):
    """
    Analyze one sentence with *recognizer* and print it, then report NER
    scores against the standard PKU199801_TEST corpus.
    """
    segmenter = PerceptronSegmenter()
    tagger = PerceptronPOSTagger()
    analyzer = AbstractLexicalAnalyzer(segmenter, tagger, recognizer)
    print(test)
    print(analyzer.analyze(test))
    # Corpus-level evaluation; independent of the single sentence above.
    Utility.printNERScore(Utility.evaluateNER(recognizer, PKU199801_TEST))


def loadArticle(fileName):
    """
    Read a raw article from the self-made test data set.

    Removes <content>/<\u002fcontent> markup, ASCII and full-width spaces, and
    surrounding whitespace from every line.

    :param fileName: path of the file to read (UTF-8)
    :return: list with one cleaned string per line of the file
    """
    cleaned = []
    with open(fileName, encoding='utf-8') as file:
        # Iterate the file object lazily instead of readlines(), which
        # would load the whole article into memory at once.
        for line in file:
            line = line.replace("<content>", "")
            line = line.replace("</content>", "")
            line = line.replace(" ", "")
            line = line.replace("　", "")  # full-width (ideographic) space
            cleaned.append(line.strip())
    return cleaned


if __name__ == '__main__':
    # Train on the PKU train split, then run NER over each cleaned line of
    # the local test article.
    recognizer = train(PKU199801_TRAIN, NER_MODEL)
    articles = loadArticle("test.txt")

    # Iterate directly instead of indexing via range(len(...)).
    for article in articles:
        test(recognizer, article)
